{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "801db364",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import clip\n",
    "import coremltools as ct\n",
    "import numpy as np\n",
    "from PIL import Image"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d63f93de",
   "metadata": {},
   "source": [
    "# 1. Load ViT-B/32 CLIP model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9ebc2db9",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Load the reference PyTorch CLIP model on CPU and run one forward pass as a sanity check.\n",
    "device = \"cpu\"\n",
    "model, preprocess = clip.load(\"ViT-B/32\", device=device)\n",
    "text = clip.tokenize(\"a diagram\").to(device)\n",
    "\n",
    "# Open the sample image in a context manager so the file handle is released as soon as\n",
    "# `preprocess` has turned it into the input tensor.\n",
    "with Image.open(\"IMG_7466.jpg\") as sample_image:\n",
    "    image = preprocess(sample_image).unsqueeze(0).to(device)\n",
    "\n",
    "with torch.no_grad():\n",
    "    image_features = model.encode_image(image)\n",
    "    text_features = model.encode_text(text)\n",
    "    logits_per_image, logits_per_text = model(image, text)\n",
    "    probs = logits_per_image.softmax(dim=-1).cpu().numpy()\n",
    "\n",
    "# Trace of the full two-input model; the export sections below trace each encoder separately.\n",
    "traced = torch.jit.trace(model, (image, text))"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "26f7dcff",
   "metadata": {},
   "source": [
    "# 2. Export TextEncoder"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8f89976b",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch.nn as nn\n",
    "from collections import OrderedDict\n",
    "\n",
    "class ResidualAttentionBlock(nn.Module):\n",
    "    \"\"\"Transformer block: LayerNorm -> self-attention and LayerNorm -> MLP, each added back to the input.\"\"\"\n",
    "\n",
    "    def __init__(self, d_model: int, n_head: int, attn_mask: torch.Tensor = None):\n",
    "        super().__init__()\n",
    "\n",
    "        hidden_dim = d_model * 4\n",
    "\n",
    "        self.attn = nn.MultiheadAttention(d_model, n_head)\n",
    "        self.ln_1 = LayerNorm(d_model)\n",
    "\n",
    "        # Sub-module names (c_fc / gelu / c_proj) are part of the state-dict keys.\n",
    "        mlp_layers = OrderedDict()\n",
    "        mlp_layers[\"c_fc\"] = nn.Linear(d_model, hidden_dim)\n",
    "        mlp_layers[\"gelu\"] = QuickGELU()\n",
    "        mlp_layers[\"c_proj\"] = nn.Linear(hidden_dim, d_model)\n",
    "        self.mlp = nn.Sequential(mlp_layers)\n",
    "\n",
    "        self.ln_2 = LayerNorm(d_model)\n",
    "        self.attn_mask = attn_mask\n",
    "\n",
    "    def attention(self, x: torch.Tensor):\n",
    "        \"\"\"Self-attention over `x`, using the stored mask (if any) cast to x's dtype and device.\"\"\"\n",
    "        mask = self.attn_mask\n",
    "        if mask is not None:\n",
    "            mask = mask.to(dtype=x.dtype, device=x.device)\n",
    "        # The converted mask is written back to the attribute, as before.\n",
    "        self.attn_mask = mask\n",
    "        attn_result = self.attn(x, x, x, need_weights=False, attn_mask=mask)\n",
    "        return attn_result[0]\n",
    "\n",
    "    def forward(self, x: torch.Tensor):\n",
    "        \"\"\"Apply the attention and MLP sub-layers, each with a residual connection.\"\"\"\n",
    "        attn_branch = self.attention(self.ln_1(x))\n",
    "        x = x + attn_branch\n",
    "        mlp_branch = self.mlp(self.ln_2(x))\n",
    "        return x + mlp_branch\n",
    "    \n",
    "class Transformer(nn.Module):\n",
    "    \"\"\"Stack of `layers` ResidualAttentionBlock modules that all receive the same attention mask.\"\"\"\n",
    "\n",
    "    def __init__(self, width: int, layers: int, heads: int, attn_mask: torch.Tensor = None):\n",
    "        super().__init__()\n",
    "        self.width = width\n",
    "        self.layers = layers\n",
    "\n",
    "        blocks = []\n",
    "        for _ in range(layers):\n",
    "            blocks.append(ResidualAttentionBlock(width, heads, attn_mask))\n",
    "        self.resblocks = nn.Sequential(*blocks)\n",
    "\n",
    "    def forward(self, x: torch.Tensor):\n",
    "        \"\"\"Run `x` through every residual block in order.\"\"\"\n",
    "        return self.resblocks(x)\n",
    "\n",
    "class LayerNorm(nn.LayerNorm):\n",
    "    \"\"\"LayerNorm that computes in float32 and returns the result in the input's dtype (fp16-safe).\"\"\"\n",
    "\n",
    "    def forward(self, x: torch.Tensor):\n",
    "        input_dtype = x.dtype\n",
    "        as_float32 = x.type(torch.float32)\n",
    "        normalized = super().forward(as_float32)\n",
    "        return normalized.type(input_dtype)\n",
    "\n",
    "class QuickGELU(nn.Module):\n",
    "    \"\"\"Activation computed as x * sigmoid(1.702 * x).\"\"\"\n",
    "\n",
    "    def forward(self, x: torch.Tensor):\n",
    "        gate = torch.sigmoid(1.702 * x)\n",
    "        return x * gate"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c87abd71",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch.nn as nn\n",
    "\n",
    "class TextEncoder(nn.Module):\n",
    "    \"\"\"Stand-alone text tower that can be traced and exported without the image branch.\n",
    "\n",
    "    Attribute names are kept as-is because the CLIP checkpoint is loaded into this module\n",
    "    by state-dict key.\n",
    "\n",
    "    Args:\n",
    "        embed_dim: Size of the output embedding (number of columns of ``text_projection``).\n",
    "        context_length: Number of token positions per prompt.\n",
    "        vocab_size: Number of rows of the token embedding table.\n",
    "        transformer_width: Hidden size of the transformer.\n",
    "        transformer_heads: Attention heads per residual block.\n",
    "        transformer_layers: Number of residual blocks.\n",
    "    \"\"\"\n",
    "\n",
    "    def __init__(self,\n",
    "                 embed_dim: int,\n",
    "                 # text\n",
    "                 context_length: int,\n",
    "                 vocab_size: int,\n",
    "                 transformer_width: int,\n",
    "                 transformer_heads: int,\n",
    "                 transformer_layers: int\n",
    "                 ):\n",
    "        super().__init__()\n",
    "\n",
    "        # Must be set before build_attention_mask() is called below.\n",
    "        self.context_length = context_length\n",
    "\n",
    "        self.transformer = Transformer(\n",
    "                width=transformer_width,\n",
    "                layers=transformer_layers,\n",
    "                heads=transformer_heads,\n",
    "                attn_mask=self.build_attention_mask()\n",
    "        )\n",
    "\n",
    "        self.vocab_size = vocab_size\n",
    "        self.token_embedding = nn.Embedding(vocab_size, transformer_width)\n",
    "        self.positional_embedding = nn.Parameter(torch.empty(self.context_length, transformer_width))\n",
    "        self.ln_final = LayerNorm(transformer_width)\n",
    "\n",
    "        # Neither of these two parameters is read by forward().\n",
    "        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))\n",
    "        self.temperature = nn.Parameter(torch.tensor(0.07))\n",
    "\n",
    "        self.text_projection = nn.Parameter(torch.empty(transformer_width, embed_dim))\n",
    "\n",
    "        print(f\"text_projection shape: {self.text_projection.shape}\")\n",
    "        self.dtype = torch.float32\n",
    "\n",
    "        self.initialize_parameters()\n",
    "\n",
    "    def initialize_parameters(self):\n",
    "        \"\"\"Randomly initialize the embeddings, every block's attention/MLP weights and the projection.\"\"\"\n",
    "        nn.init.normal_(self.token_embedding.weight, std=0.02)\n",
    "        nn.init.normal_(self.positional_embedding, std=0.01)\n",
    "\n",
    "        proj_std = (self.transformer.width ** -0.5) * ((2 * self.transformer.layers) ** -0.5)\n",
    "        attn_std = self.transformer.width ** -0.5\n",
    "        fc_std = (2 * self.transformer.width) ** -0.5\n",
    "        for block in self.transformer.resblocks:\n",
    "            nn.init.normal_(block.attn.in_proj_weight, std=attn_std)\n",
    "            nn.init.normal_(block.attn.out_proj.weight, std=proj_std)\n",
    "            nn.init.normal_(block.mlp.c_fc.weight, std=fc_std)\n",
    "            nn.init.normal_(block.mlp.c_proj.weight, std=proj_std)\n",
    "\n",
    "        # text_projection is always created in __init__, so no None check is needed here.\n",
    "        nn.init.normal_(self.text_projection, std=self.transformer.width ** -0.5)\n",
    "\n",
    "    def build_attention_mask(self):\n",
    "        \"\"\"Return the additive causal mask of shape (context_length, context_length).\n",
    "\n",
    "        Entries strictly above the diagonal are -inf, so a position cannot attend to later\n",
    "        positions; the diagonal and everything below it are 0.\n",
    "        \"\"\"\n",
    "        # pytorch uses additive attention mask; fill with -inf\n",
    "        mask = torch.empty(self.context_length, self.context_length)\n",
    "        mask.fill_(float(\"-inf\"))\n",
    "        mask.triu_(1)  # keep -inf above the diagonal, zero the diagonal and below\n",
    "        return mask\n",
    "\n",
    "    def forward(self, text):\n",
    "        \"\"\"Encode tokenized prompts into embeddings.\n",
    "\n",
    "        Args:\n",
    "            text: Integer token ids, indexed below as [batch_size, n_ctx].\n",
    "\n",
    "        Returns:\n",
    "            Tensor of shape [batch_size, embed_dim].\n",
    "        \"\"\"\n",
    "        x = self.token_embedding(text).type(self.dtype)  # [batch_size, n_ctx, d_model]\n",
    "\n",
    "        x = x + self.positional_embedding.type(self.dtype)\n",
    "        x = x.permute(1, 0, 2)  # NLD -> LND\n",
    "        x = self.transformer(x)\n",
    "        x = x.permute(1, 0, 2)  # LND -> NLD\n",
    "        x = self.ln_final(x).type(self.dtype)\n",
    "        # x.shape = [batch_size, n_ctx, transformer.width]\n",
    "        # take features from the eot embedding (eot_token is the highest number in each sequence)\n",
    "        x = x[torch.arange(x.shape[0]), text.argmax(dim=-1)] @ self.text_projection\n",
    "\n",
    "        return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c018fe96",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Sizes must match the tensors of the checkpoint that is loaded into this module in the next cell.\n",
    "text_encoder = TextEncoder(\n",
    "    embed_dim=512,\n",
    "    context_length=77,\n",
    "    vocab_size=49408,\n",
    "    transformer_width=512,\n",
    "    transformer_heads=8,\n",
    "    transformer_layers=12,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "658075c8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# strict=False is needed because the CLIP state dict also contains the image branch, which\n",
    "# TextEncoder does not have. Check the missing keys explicitly so a text weight that was\n",
    "# silently left at its random initialization cannot go unnoticed.\n",
    "load_result = text_encoder.load_state_dict(model.state_dict(), strict=False)\n",
    "\n",
    "# `temperature` is defined only on TextEncoder, so it is the one key allowed to be missing.\n",
    "missing_text_keys = sorted(set(load_result.missing_keys) - {\"temperature\"})\n",
    "assert not missing_text_keys, f\"TextEncoder weights missing from checkpoint: {missing_text_keys}\"\n",
    "\n",
    "print(f\"missing keys: {load_result.missing_keys}\")\n",
    "print(f\"unexpected (ignored) keys: {len(load_result.unexpected_keys)}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "af2e63ed",
   "metadata": {},
   "outputs": [],
   "source": [
    "import coremltools as ct\n",
    "\n",
    "text_encoder.eval()\n",
    "\n",
    "example_input = clip.tokenize(\"a diagram\").to(device)\n",
    "\n",
    "# Trace and compute the PyTorch reference output with autograd disabled: no gradient graph\n",
    "# is built, and `out` is a plain tensor for the comparison further down.\n",
    "with torch.no_grad():\n",
    "    traced_model = torch.jit.trace(text_encoder, example_input)\n",
    "    out = traced_model(example_input)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "faa8e249",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Fixed number of token positions of the exported `prompt` input.\n",
    "max_seq_length = 77\n",
    "\n",
    "# Convert the traced text encoder to a Core ML ML Program with one int32 input and one float32 output.\n",
    "text_encoder_model = ct.convert(\n",
    "    traced_model,\n",
    "    convert_to=\"mlprogram\",\n",
    "    minimum_deployment_target=ct.target.iOS16,\n",
    "    inputs=[\n",
    "        ct.TensorType(name=\"prompt\", shape=[1, max_seq_length], dtype=np.int32),\n",
    "    ],\n",
    "    outputs=[\n",
    "        ct.TensorType(name=\"embOutput\", dtype=np.float32),\n",
    "    ],\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b4c04977",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Write the converted text encoder to disk; the validation cell below loads it back from this path.\n",
    "text_encoder_model.save(\n",
    "    \"TextEncoder_float32.mlpackage\"\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "617e4e6b",
   "metadata": {},
   "source": [
    "## Validate export precision"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "fd6af02a",
   "metadata": {},
   "outputs": [],
   "source": [
    "import coremltools as ct\n",
    "\n",
    "# Load the exported package under its own name so the PyTorch CLIP `model` is not overwritten\n",
    "# (earlier cells that call `model.state_dict()` stay re-runnable).\n",
    "text_encoder_mlmodel = ct.models.MLModel('TextEncoder_float32.mlpackage')\n",
    "text = clip.tokenize(\"a diagram\").to(device)\n",
    "\n",
    "# Pass a NumPy array in the int32 dtype that was declared for the `prompt` input at conversion.\n",
    "prompt_tokens = text.cpu().numpy().astype(np.int32)\n",
    "predictions = text_encoder_mlmodel.predict({'prompt': prompt_tokens})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "c29d0a98",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(\"PyTorch TextEncoder ckpt out for \\\"a diagram\\\":\\n>>>\", out[0, :10])\n",
    "print(\"\\nCoreML TextEncoder ckpt out for \\\"a diagram\\\":\\n>>>\", predictions['embOutput'][0, :10])\n",
    "\n",
    "# The two prints above show only the first 10 values; also report the largest deviation over\n",
    "# the whole embedding so the precision loss is a number rather than a visual impression.\n",
    "torch_embedding = out.detach().cpu().numpy()\n",
    "coreml_embedding = np.asarray(predictions['embOutput'])\n",
    "max_abs_diff = np.abs(torch_embedding - coreml_embedding).max()\n",
    "print(f\"\\nmax |PyTorch - CoreML| over {torch_embedding.size} values: {max_abs_diff:.3e}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "3c0d9c70",
   "metadata": {},
   "source": [
    "You can see that there is some loss in precision, but it is still acceptable."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ca182b4a",
   "metadata": {},
   "source": [
    "# 3. Export ImageEncoder"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4dd4e99f",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import clip\n",
    "import coremltools as ct\n",
    "import numpy as np\n",
    "from PIL import Image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "68521589",
   "metadata": {},
   "outputs": [],
   "source": [
    "device = \"cpu\"\n",
    "model, preprocess = clip.load(\"ViT-B/32\", device=device)\n",
    "\n",
    "# Context manager closes the image file once `preprocess` has produced the input tensor.\n",
    "with Image.open(\"IMG_3628.jpg\") as source_image:\n",
    "    image_orig = preprocess(source_image).unsqueeze(0).to(device)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "304ae7b0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Trace only the image branch and keep its PyTorch output as the reference for validation.\n",
    "# Autograd is disabled so no gradient graph is built and `out` is a plain tensor.\n",
    "with torch.no_grad():\n",
    "    traced_image_only = torch.jit.trace(model.visual, image_orig)\n",
    "    out = traced_image_only(image_orig)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f9bab986",
   "metadata": {},
   "outputs": [],
   "source": [
    "import coremltools as ct\n",
    "\n",
    "# Per-channel (R, G, B) mean and std that the scale/bias below are derived from.\n",
    "CLIP_MEAN = (0.48145466, 0.4578275, 0.40821073)\n",
    "CLIP_STD = (0.26862954, 0.26130258, 0.27577711)\n",
    "\n",
    "# Core ML preprocesses each 0-255 pixel as `pixel * scale + bias`.\n",
    "# A single scalar `scale` is passed for all three channels, so the per-channel std is\n",
    "# approximated by the average std (~0.2685697); `bias` is -mean/std for each channel.\n",
    "# NOTE(review): this makes the normalization only approximate for channels whose std differs\n",
    "# from the average; an exact match would need the normalization inside the traced model.\n",
    "average_std = sum(CLIP_STD) / len(CLIP_STD)\n",
    "scale = 1 / (average_std * 255.0)\n",
    "bias = [-mean / std for mean, std in zip(CLIP_MEAN, CLIP_STD)]\n",
    "\n",
    "image_input_scale = ct.ImageType(name=\"colorImage\",\n",
    "                           color_layout=ct.colorlayout.RGB,\n",
    "                           shape=image_orig.shape,\n",
    "                           scale=scale, bias=bias)\n",
    "\n",
    "\n",
    "image_encoder_model = ct.convert(\n",
    "            traced_image_only,\n",
    "            convert_to=\"mlprogram\",\n",
    "            minimum_deployment_target=ct.target.iOS16,\n",
    "            inputs=[image_input_scale],\n",
    "            outputs=[ct.TensorType(name=\"embOutput\", dtype=np.float32)],\n",
    "        )\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cb6623a8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Write the converted image encoder to disk; the validation cell below loads it back from this path.\n",
    "image_encoder_model.save(\n",
    "    \"ImageEncoder_float32.mlpackage\"\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "f3c5008e",
   "metadata": {},
   "source": [
    "## Validate export"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "759bb57d",
   "metadata": {},
   "outputs": [],
   "source": [
    "import coremltools as ct\n",
    "\n",
    "# Load the model\n",
    "image_encoder = ct.models.MLModel('ImageEncoder_float32.mlpackage')\n",
    "\n",
    "from torchvision import transforms\n",
    "\n",
    "# The PyTorch reference `out` was computed on CLIP's `preprocess` output, which resizes the\n",
    "# shorter side to 224 (bicubic) and center-crops to 224x224. A plain resize to (224, 224)\n",
    "# squashes any non-square photo, so the two models would be compared on different pixels.\n",
    "# Apply the same geometry here; normalization is done by the Core ML model itself.\n",
    "resize_and_crop = transforms.Compose([\n",
    "    transforms.Resize(224, interpolation=transforms.InterpolationMode.BICUBIC),\n",
    "    transforms.CenterCrop(224),\n",
    "])\n",
    "with Image.open(\"IMG_3628.jpg\") as source_image:\n",
    "    imgPIL = resize_and_crop(source_image).convert(\"RGB\")\n",
    "\n",
    "predictions = image_encoder.predict({'colorImage': imgPIL})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "4a8c2a34",
   "metadata": {},
   "outputs": [],
   "source": [
    "print(\"PyTorch ImageEncoder ckpt out for IMG_3628.jpg:\\n>>>\", out[0, :10])\n",
    "print(\"\\nCoreML ImageEncoder ckpt out for IMG_3628.jpg:\\n>>>\", predictions['embOutput'][0, :10])\n",
    "\n",
    "# The two prints above show only the first 10 values; also report the largest deviation over\n",
    "# the whole embedding so the error can be compared between experiments.\n",
    "torch_embedding = out.detach().cpu().numpy()\n",
    "coreml_embedding = np.asarray(predictions['embOutput'])\n",
    "max_abs_diff = np.abs(torch_embedding - coreml_embedding).max()\n",
    "print(f\"\\nmax |PyTorch - CoreML| over {torch_embedding.size} values: {max_abs_diff:.3e}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2b71d068",
   "metadata": {},
   "source": [
    "This time <span style='color:red'>the precision error is larger.</span> This may be caused by incorrect input normalization."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "aac310d4",
   "metadata": {},
   "source": [
    "## What if we skip normalization?"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f24ec713",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Experiment: convert the same traced image encoder WITHOUT scale/bias preprocessing.\n",
    "# Everything here uses `_no_norm` names and its own package path, so the normalized\n",
    "# ImageEncoder_float32.mlpackage exported above (and its variables) are not overwritten.\n",
    "image_input_no_norm = ct.ImageType(name=\"colorImage\",\n",
    "                           color_layout=ct.colorlayout.RGB,\n",
    "                           shape=image_orig.shape)\n",
    "\n",
    "\n",
    "image_encoder_model_no_norm = ct.convert(\n",
    "            traced_image_only,\n",
    "            convert_to=\"mlprogram\",\n",
    "            minimum_deployment_target=ct.target.iOS16,\n",
    "            inputs=[image_input_no_norm],\n",
    "            outputs=[ct.TensorType(name=\"embOutput\", dtype=np.float32)],\n",
    "        )\n",
    "\n",
    "image_encoder_model_no_norm.save(\"ImageEncoder_float32_no_norm.mlpackage\")\n",
    "\n",
    "image_encoder_no_norm = ct.models.MLModel('ImageEncoder_float32_no_norm.mlpackage')\n",
    "\n",
    "from torchvision import transforms\n",
    "\n",
    "# Use the same geometry as CLIP's `preprocess` (shorter side to 224, bicubic, then a 224x224\n",
    "# center crop) so that only the missing normalization differs from the PyTorch reference.\n",
    "resize_and_crop = transforms.Compose([\n",
    "    transforms.Resize(224, interpolation=transforms.InterpolationMode.BICUBIC),\n",
    "    transforms.CenterCrop(224),\n",
    "])\n",
    "with Image.open(\"IMG_3628.jpg\") as source_image:\n",
    "    imgPIL = resize_and_crop(source_image).convert(\"RGB\")\n",
    "\n",
    "predictions_no_norm = image_encoder_no_norm.predict({'colorImage': imgPIL})\n",
    "\n",
    "print(\"PyTorch ImageEncoder ckpt out for IMG_3628.jpg:\\n>>>\", out[0, :10])\n",
    "print(\"\\nCoreML ImageEncoder ckpt out for IMG_3628.jpg:\\n>>>\", predictions_no_norm['embOutput'][0, :10])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "386daa59",
   "metadata": {},
   "source": [
    "**The error is even worse.**"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.15"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
